﻿using Org.BouncyCastle.Crypto.Parameters;
using Org.BouncyCastle.Security;
using System.Security.Cryptography;
using System.Text;

namespace PmSoft.Core.Cryptos
{
	/// <summary>
	/// Utility methods for loading RSA keys and converting between key formats.
	/// </summary>
	public static class RsaKeyUtility
	{
		private const string Rsa1Type = "RSA";
		private const string Rsa2Type = "RSA2";

		// Heuristic: a Base64 key body longer than this many characters is treated as an "RSA2" (2048-bit) key.
		private const int Rsa1KeyLengthThreshold = 1024;

		private const string Pkcs1Format = "PKCS1";
		private const string Pkcs8Format = "PKCS8";

		// ASN.1 DER tag byte for INTEGER.
		private const byte Asn1IntegerTag = 0x02;

		/// <summary>
		/// Loads an RSA private key from a string; both PKCS#1 and PKCS#8 encodings are supported.
		/// </summary>
		/// <param name="privateKey">The private key, either as a PEM block or as a bare Base64 body.</param>
		/// <param name="keyFormat">Key encoding, "PKCS1" or "PKCS8" (compared case-insensitively); defaults to "PKCS8".</param>
		/// <returns>An <see cref="RSACryptoServiceProvider"/> holding the imported key.</returns>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="privateKey"/> is null.</exception>
		/// <exception cref="ArgumentException">Thrown when <paramref name="privateKey"/> contains no key material, or when <paramref name="keyFormat"/> is not supported.</exception>
		/// <exception cref="FormatException">Thrown when the key body is not valid Base64.</exception>
		/// <exception cref="CryptographicException">Thrown when the key data is not a valid RSA key in the requested format.</exception>
		public static RSACryptoServiceProvider LoadPrivateKey(string privateKey, string keyFormat = "PKCS8")
		{
			ArgumentNullException.ThrowIfNull(privateKey, nameof(privateKey));
			string trimmedKey = NormalizePemKey(privateKey);
			if (trimmedKey.Length == 0)
			{
				throw new ArgumentException("私钥内容为空", nameof(privateKey));
			}

			string signType = trimmedKey.Length > Rsa1KeyLengthThreshold ? Rsa2Type : Rsa1Type;

			if (string.Equals(keyFormat, Pkcs1Format, StringComparison.OrdinalIgnoreCase))
			{
				return LoadPrivateKeyFromPkcs1(trimmedKey, signType);
			}
			if (string.Equals(keyFormat, Pkcs8Format, StringComparison.OrdinalIgnoreCase))
			{
				return LoadPrivateKeyFromPkcs8(trimmedKey);
			}

			throw new ArgumentException("不支持的密钥格式，仅支持 PKCS1 和 PKCS8", nameof(keyFormat));
		}

		/// <summary>
		/// Loads an RSA public key (X.509 SubjectPublicKeyInfo, PEM or bare Base64) from a string.
		/// </summary>
		/// <param name="publicKey">The public key string.</param>
		/// <returns>An <see cref="RSACryptoServiceProvider"/> holding the imported public key.</returns>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="publicKey"/> is null.</exception>
		/// <exception cref="ArgumentException">Thrown when <paramref name="publicKey"/> contains no key material.</exception>
		/// <exception cref="FormatException">Thrown when the key body is not valid Base64.</exception>
		/// <exception cref="CryptographicException">Thrown when the decoded key is not an RSA public key.</exception>
		public static RSACryptoServiceProvider LoadPublicKey(string publicKey)
		{
			ArgumentNullException.ThrowIfNull(publicKey, nameof(publicKey));
			string trimmedKey = NormalizePemKey(publicKey);
			if (trimmedKey.Length == 0)
			{
				throw new ArgumentException("公钥内容为空", nameof(publicKey));
			}

			string xmlKey = ConvertPublicKeyToDotNetXml(trimmedKey);
			return CreateFromXml(xmlKey);
		}

		/// <summary>
		/// Loads a private key from a Base64-encoded PKCS#1 RSAPrivateKey structure.
		/// </summary>
		private static RSACryptoServiceProvider LoadPrivateKeyFromPkcs1(string privateKey, string signType)
		{
			byte[] keyBytes = Convert.FromBase64String(privateKey);
			return ParsePkcs1PrivateKey(keyBytes, signType);
		}

		/// <summary>
		/// Loads a private key from a Base64-encoded PKCS#8 PrivateKeyInfo structure.
		/// </summary>
		private static RSACryptoServiceProvider LoadPrivateKeyFromPkcs8(string privateKey)
		{
			string xmlKey = ConvertPrivateKeyToDotNetXml(privateKey);
			return CreateFromXml(xmlKey);
		}

		/// <summary>
		/// Creates a provider and imports an XML-encoded key into it.
		/// The provider is disposed before the exception propagates if the import fails.
		/// </summary>
		private static RSACryptoServiceProvider CreateFromXml(string xmlKey)
		{
			RSACryptoServiceProvider rsa = new();
			try
			{
				rsa.FromXmlString(xmlKey);
				return rsa;
			}
			catch
			{
				rsa.Dispose();
				throw;
			}
		}

		/// <summary>
		/// Converts a Base64 PKCS#8 (Java-style) private key into the .NET RSAKeyValue XML format.
		/// </summary>
		/// <exception cref="CryptographicException">Thrown when the decoded key is not an RSA CRT private key.</exception>
		private static string ConvertPrivateKeyToDotNetXml(string privateKey)
		{
			if (PrivateKeyFactory.CreateKey(Convert.FromBase64String(privateKey)) is not RsaPrivateCrtKeyParameters keyParams)
			{
				throw new CryptographicException("私钥不是有效的 RSA 私钥");
			}

			return $"<RSAKeyValue><Modulus>{Convert.ToBase64String(keyParams.Modulus.ToByteArrayUnsigned())}</Modulus>" +
				   $"<Exponent>{Convert.ToBase64String(keyParams.PublicExponent.ToByteArrayUnsigned())}</Exponent>" +
				   $"<P>{Convert.ToBase64String(keyParams.P.ToByteArrayUnsigned())}</P>" +
				   $"<Q>{Convert.ToBase64String(keyParams.Q.ToByteArrayUnsigned())}</Q>" +
				   $"<DP>{Convert.ToBase64String(keyParams.DP.ToByteArrayUnsigned())}</DP>" +
				   $"<DQ>{Convert.ToBase64String(keyParams.DQ.ToByteArrayUnsigned())}</DQ>" +
				   $"<InverseQ>{Convert.ToBase64String(keyParams.QInv.ToByteArrayUnsigned())}</InverseQ>" +
				   $"<D>{Convert.ToBase64String(keyParams.Exponent.ToByteArrayUnsigned())}</D></RSAKeyValue>";
		}

		/// <summary>
		/// Converts a Base64 X.509 SubjectPublicKeyInfo (Java-style) public key into the .NET RSAKeyValue XML format.
		/// </summary>
		/// <exception cref="CryptographicException">Thrown when the decoded key is not an RSA key.</exception>
		private static string ConvertPublicKeyToDotNetXml(string publicKey)
		{
			if (PublicKeyFactory.CreateKey(Convert.FromBase64String(publicKey)) is not RsaKeyParameters keyParams)
			{
				throw new CryptographicException("公钥不是有效的 RSA 公钥");
			}

			return $"<RSAKeyValue><Modulus>{Convert.ToBase64String(keyParams.Modulus.ToByteArrayUnsigned())}</Modulus>" +
				   $"<Exponent>{Convert.ToBase64String(keyParams.Exponent.ToByteArrayUnsigned())}</Exponent></RSAKeyValue>";
		}

		/// <summary>
		/// Normalizes a PEM key string by removing the recognized BEGIN/END markers and CR/LF characters,
		/// then trimming surrounding whitespace, leaving just the Base64 body.
		/// </summary>
		private static string NormalizePemKey(string key)
		{
			return key.Replace("-----BEGIN RSA PRIVATE KEY-----", "")
					  .Replace("-----END RSA PRIVATE KEY-----", "")
					  .Replace("-----BEGIN PRIVATE KEY-----", "")
					  .Replace("-----END PRIVATE KEY-----", "")
					  .Replace("-----BEGIN PUBLIC KEY-----", "")
					  .Replace("-----END PUBLIC KEY-----", "")
					  .Replace("\r", "")
					  .Replace("\n", "")
					  .Trim();
		}

		/// <summary>
		/// Parses the DER bytes of a PKCS#1 RSAPrivateKey and imports them into a new provider.
		/// </summary>
		/// <param name="keyBytes">DER-encoded RSAPrivateKey (SEQUENCE with a two-byte or one-byte long-form length).</param>
		/// <param name="signType"><c>"RSA2"</c> selects a 2048-bit initial key size hint, anything else 1024.</param>
		/// <exception cref="CryptographicException">Thrown for any parse or import failure; the original error is the inner exception.</exception>
		private static RSACryptoServiceProvider ParsePkcs1PrivateKey(byte[] keyBytes, string signType)
		{
			using MemoryStream stream = new(keyBytes);
			using BinaryReader reader = new(stream);
			RSACryptoServiceProvider? rsa = null;
			try
			{
				// The outer SEQUENCE header, read little-endian: 0x30 0x81 (1 length byte) or 0x30 0x82 (2 length bytes).
				ushort header = reader.ReadUInt16();
				if (header == 0x8130) reader.ReadByte();
				else if (header == 0x8230) reader.ReadInt16();
				else throw new CryptographicException("无效的密钥头格式");

				// version INTEGER must be encoded as 0x02 0x01 0x00 (two-prime, version 0).
				if (reader.ReadUInt16() != 0x0102 || reader.ReadByte() != 0)
					throw new CryptographicException("无效的密钥版本或填充");

				// Field order is fixed by the RSAPrivateKey ASN.1 definition.
				byte[] modulus = ReadInteger(reader);
				byte[] exponent = ReadInteger(reader);
				byte[] d = ReadInteger(reader);
				byte[] p = ReadInteger(reader);
				byte[] q = ReadInteger(reader);
				byte[] dp = ReadInteger(reader);
				byte[] dq = ReadInteger(reader);
				byte[] inverseQ = ReadInteger(reader);

				// ImportParameters expects D to match the modulus length and the CRT values to be half of it;
				// leading zeros stripped by the DER reader are restored here.
				int halfLength = (modulus.Length + 1) / 2;
				RSAParameters rsaParams = new()
				{
					Modulus = modulus,
					Exponent = exponent,
					D = PadLeft(d, modulus.Length),
					P = PadLeft(p, halfLength),
					Q = PadLeft(q, halfLength),
					DP = PadLeft(dp, halfLength),
					DQ = PadLeft(dq, halfLength),
					InverseQ = PadLeft(inverseQ, halfLength)
				};

				int keySize = signType == Rsa2Type ? 2048 : 1024;
				rsa = CreateProvider(keySize);
				rsa.ImportParameters(rsaParams);
				return rsa;
			}
			catch (Exception ex)
			{
				rsa?.Dispose();
				throw new CryptographicException("解析 PKCS1 私钥失败", ex);
			}
		}

		/// <summary>
		/// Creates an empty provider with the given initial key size.
		/// On Windows, it is configured to use the machine key store.
		/// </summary>
		private static RSACryptoServiceProvider CreateProvider(int keySize)
		{
			if (OperatingSystem.IsWindows())
			{
				CspParameters cspParams = new() { Flags = CspProviderFlags.UseMachineKeyStore };
				return new RSACryptoServiceProvider(keySize, cspParams);
			}

			// The CspParameters constructor overload is Windows-only and throws PlatformNotSupportedException elsewhere.
			return new RSACryptoServiceProvider(keySize);
		}

		/// <summary>
		/// Reads one DER INTEGER (tag, length, value) and returns its magnitude bytes with leading zeros removed.
		/// </summary>
		/// <exception cref="CryptographicException">Thrown when the data ends before the declared length.</exception>
		private static byte[] ReadInteger(BinaryReader reader)
		{
			int size = GetIntegerSize(reader);
			byte[] value = reader.ReadBytes(size);
			if (value.Length != size)
			{
				throw new CryptographicException("密钥数据被截断");
			}
			return value;
		}

		/// <summary>
		/// Left-pads a big-endian unsigned integer with zero bytes to <paramref name="length"/>.
		/// The input is returned unchanged when it is already at least that long.
		/// </summary>
		private static byte[] PadLeft(byte[] value, int length)
		{
			if (value.Length >= length)
			{
				return value;
			}

			byte[] padded = new byte[length];
			Buffer.BlockCopy(value, 0, padded, length - value.Length, value.Length);
			return padded;
		}

		/// <summary>
		/// Reads the tag and length of a DER INTEGER and returns the length of its value.
		/// Leading 0x00 bytes are skipped (consumed), so the returned length covers only the remaining bytes.
		/// On return, the stream is positioned at the first byte still to be read.
		/// </summary>
		/// <exception cref="CryptographicException">Thrown when the tag is not INTEGER or the length encoding is unsupported.</exception>
		private static int GetIntegerSize(BinaryReader reader)
		{
			byte tag = reader.ReadByte();
			if (tag != Asn1IntegerTag)
			{
				throw new CryptographicException("无效的 ASN.1 整数标签");
			}

			byte lengthTag = reader.ReadByte();
			int size = lengthTag switch
			{
				0x81 => reader.ReadByte(),
				// Two length bytes, big-endian per DER: the first byte read is the high byte.
				0x82 => (reader.ReadByte() << 8) | reader.ReadByte(),
				< 0x80 => lengthTag,
				_ => throw new CryptographicException("不支持的 ASN.1 长度编码")
			};

			// Skip leading zero bytes (e.g. the DER sign byte) without reading past this INTEGER's value.
			while (size > 0 && reader.ReadByte() == 0) size--;
			if (size > 0) reader.BaseStream.Seek(-1, SeekOrigin.Current);
			return size;
		}
	}
}